{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch: Combining Multiple Shards of Data\n",
    "\n",
    "The `AISMultiShardStream` class facilitates combining multiple streams of data shards into one iterable dataset. It takes a list of `DataShard` objects as input, each representing a shard stream. When iterated over, it yields combined samples, where each sample is a tuple containing object bytes from each shard stream. This is particularly useful for scenarios where data is stored in separate shards. \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup client and necessary bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import io\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "import tarfile\n",
    "from aistore.client import Client\n",
    "from aistore.sdk.dataset.data_shard import DataShard\n",
    "from aistore.pytorch import AISMultiShardStream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster endpoint: read from AIS_ENDPOINT, falling back to a local deployment\n",
    "ais_url = os.environ.get(\"AIS_ENDPOINT\", \"http://localhost:8080\")\n",
    "\n",
    "# Connect to the cluster and create the demo bucket\n",
    "# (exist_ok=True tolerates a bucket left over from a previous run)\n",
    "client = Client(ais_url)\n",
    "bucket = client.bucket(\"my-bck\").create(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Shards and adding them to our Bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility function to create a tar archive from a dictionary of file names and contents\n",
    "def create_archive(archive_name, content_dict):\n",
    "    \"\"\"Write an uncompressed tar archive built from in-memory file contents.\n",
    "\n",
    "    Args:\n",
    "        archive_name: Path of the tar file to write (str or path-like).\n",
    "            Missing parent directories are created.\n",
    "        content_dict: Mapping of member file name to its content as bytes.\n",
    "    \"\"\"\n",
    "    directory = os.path.dirname(archive_name)\n",
    "    # A bare file name has an empty dirname and os.makedirs(\"\") raises,\n",
    "    # so only create the parent when there is one. exist_ok=True avoids\n",
    "    # the check-then-create race of a separate os.path.exists() test.\n",
    "    if directory:\n",
    "        os.makedirs(directory, exist_ok=True)\n",
    "\n",
    "    with tarfile.open(archive_name, \"w\") as tar:\n",
    "        for file_name, file_content in content_dict.items():\n",
    "            # The member size must be set explicitly when adding from a file object\n",
    "            info = tarfile.TarInfo(name=file_name)\n",
    "            info.size = len(file_content)\n",
    "            tar.addfile(tarinfo=info, fileobj=io.BytesIO(file_content))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local scratch directory (under the current working directory) for the shard archives\n",
    "base_path = Path().absolute() / \"multishard_example\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will prepare two shards, each containing a different type of file: one for ***text data (text_shard.tar)*** and the other for ***class labels (class_shard.tar)***. Each shard is an uncompressed tar archive containing multiple files. These shards will be combined later using `AISMultiShardStream`, enabling us to process both text and class data simultaneously as a single stream."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_shard_content_dict = {\n",
    "    \"file1.txt\": b\"Content of file one\",\n",
    "    \"file2.txt\": b\"Content of file two\",\n",
    "    \"file3.txt\": b\"Content of file three\",\n",
    "    \"file4.txt\": b\"Content of file four\",\n",
    "    \"file5.txt\": b\"Content of file five\",\n",
    "}\n",
    "text_shard_archive_name = \"text_shard.tar\"\n",
    "text_shard_archive_path = base_path.joinpath(text_shard_archive_name)\n",
    "\n",
    "# Build the archive locally, then upload it to the bucket under the same name\n",
    "create_archive(text_shard_archive_path, text_shard_content_dict)\n",
    "text_shard_obj = bucket.object(obj_name=text_shard_archive_name)\n",
    "text_shard_obj.get_writer().put_file(text_shard_archive_path)\n",
    "\n",
    "# Create a DataShard object for the text shard.\n",
    "# The bucket name and prefix are taken from the objects used for the upload\n",
    "# instead of repeated literals, so they cannot drift from what was stored.\n",
    "shard1 = DataShard(\n",
    "    client_url=ais_url,\n",
    "    bucket_name=bucket.name,\n",
    "    prefix=text_shard_archive_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_shard_content_dict = {\n",
    "    \"file1.cls\": b\"1\",\n",
    "    \"file2.cls\": b\"2\",\n",
    "    \"file3.cls\": b\"3\",\n",
    "    \"file4.cls\": b\"4\",\n",
    "    \"file5.cls\": b\"5\",\n",
    "}\n",
    "class_shard_archive_name = \"class_shard.tar\"\n",
    "class_shard_archive_path = base_path.joinpath(class_shard_archive_name)\n",
    "\n",
    "# Build the archive locally, then upload it to the bucket under the same name\n",
    "create_archive(class_shard_archive_path, class_shard_content_dict)\n",
    "class_shard_obj = bucket.object(obj_name=class_shard_archive_name)\n",
    "class_shard_obj.get_writer().put_file(class_shard_archive_path)\n",
    "\n",
    "# Create a DataShard object for the class shard.\n",
    "# The bucket name and prefix are taken from the objects used for the upload\n",
    "# instead of repeated literals, so they cannot drift from what was stored.\n",
    "shard2 = DataShard(\n",
    "    client_url=ais_url,\n",
    "    bucket_name=bucket.name,\n",
    "    prefix=class_shard_archive_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Retrieving both shards in a single Stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the two shard streams; each sample carries one entry per data source,\n",
    "# in the order the sources are listed here\n",
    "dataset = AISMultiShardStream(data_sources=[shard1, shard2])\n",
    "\n",
    "# Unpack every combined sample straight into its text and class parts\n",
    "for text_content, class_content in dataset:\n",
    "    print(f\"Text: {text_content}, Class: {class_content}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the local scratch directory that holds the shard archives\n",
    "try:\n",
    "    shutil.rmtree(base_path)\n",
    "except FileNotFoundError:\n",
    "    # Directory is already gone (e.g. this cell was run twice) - nothing to do\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the demo bucket from the cluster.\n",
    "# missing_ok=True keeps this cell re-runnable, matching the local cleanup above\n",
    "# which also tolerates an already-removed directory.\n",
    "bucket.delete(missing_ok=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "my-python3-kernel",
   "language": "python",
   "name": "my-python3-kernel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
